from sklearn import tree
import pandas as pd
import numpy as np
df_bp = pd.read_csv("bloodpressure_train.csv")
df_bp.head()
| id | gender | age | hypertension | heart_disease | ever_married | work_type | Residence_type | avg_glucose_level | bmi | smoking_status | stroke | |
|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | 30669 | Male | 3.0 | 0 | 0 | No | children | Rural | 95.12 | 18.0 | NaN | 0 |
| 1 | 30468 | Male | 58.0 | 1 | 0 | Yes | Private | Urban | 87.96 | 39.2 | never smoked | 0 |
| 2 | 16523 | Female | 8.0 | 0 | 0 | No | Private | Urban | 110.89 | 17.6 | NaN | 0 |
| 3 | 56543 | Female | 70.0 | 0 | 0 | Yes | Private | Rural | 69.04 | 35.9 | formerly smoked | 0 |
| 4 | 46136 | Male | 14.0 | 0 | 0 | No | Never_worked | Rural | 161.28 | 19.1 | NaN | 0 |
target = df_bp["stroke"]
target_names = ["negative", "positive"]
df_bp['gender'].replace('Male',1, inplace=True)
df_bp['gender'].replace('Female',0, inplace=True)
df_bp['gender'].replace('Other',0,inplace=True)
df_bp['ever_married'].replace('Yes',1, inplace=True)
df_bp['ever_married'].replace('No',0, inplace=True)
df_bp['Residence_type'].replace('Urban',1, inplace=True)
df_bp['Residence_type'].replace('Rural',0, inplace=True)
df_bp['smoking_status'].replace('smokes',1, inplace=True)
df_bp['smoking_status'].replace('never smoked',0, inplace=True)
df_bp['smoking_status'].replace('formerly smoked',1, inplace=True)
df_bp.head()
| id | gender | age | hypertension | heart_disease | ever_married | work_type | Residence_type | avg_glucose_level | bmi | smoking_status | stroke | |
|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | 30669 | 1 | 3.0 | 0 | 0 | 0 | children | 0 | 95.12 | 18.0 | NaN | 0 |
| 1 | 30468 | 1 | 58.0 | 1 | 0 | 1 | Private | 1 | 87.96 | 39.2 | 0.0 | 0 |
| 2 | 16523 | 0 | 8.0 | 0 | 0 | 0 | Private | 1 | 110.89 | 17.6 | NaN | 0 |
| 3 | 56543 | 0 | 70.0 | 0 | 0 | 1 | Private | 0 | 69.04 | 35.9 | 1.0 | 0 |
| 4 | 46136 | 1 | 14.0 | 0 | 0 | 0 | Never_worked | 0 | 161.28 | 19.1 | NaN | 0 |
BPdata = df_bp.drop(["id","work_type", "stroke"], axis=1)
feature_names = BPdata.columns
BPdata.head()
| gender | age | hypertension | heart_disease | ever_married | Residence_type | avg_glucose_level | bmi | smoking_status | |
|---|---|---|---|---|---|---|---|---|---|
| 0 | 1 | 3.0 | 0 | 0 | 0 | 0 | 95.12 | 18.0 | NaN |
| 1 | 1 | 58.0 | 1 | 0 | 1 | 1 | 87.96 | 39.2 | 0.0 |
| 2 | 0 | 8.0 | 0 | 0 | 0 | 1 | 110.89 | 17.6 | NaN |
| 3 | 0 | 70.0 | 0 | 0 | 1 | 0 | 69.04 | 35.9 | 1.0 |
| 4 | 1 | 14.0 | 0 | 0 | 0 | 0 | 161.28 | 19.1 | NaN |
BPdata.fillna(0)
| gender | age | hypertension | heart_disease | ever_married | Residence_type | avg_glucose_level | bmi | smoking_status | |
|---|---|---|---|---|---|---|---|---|---|
| 0 | 1 | 3.0 | 0 | 0 | 0 | 0 | 95.12 | 18.0 | 0.0 |
| 1 | 1 | 58.0 | 1 | 0 | 1 | 1 | 87.96 | 39.2 | 0.0 |
| 2 | 0 | 8.0 | 0 | 0 | 0 | 1 | 110.89 | 17.6 | 0.0 |
| 3 | 0 | 70.0 | 0 | 0 | 1 | 0 | 69.04 | 35.9 | 1.0 |
| 4 | 1 | 14.0 | 0 | 0 | 0 | 0 | 161.28 | 19.1 | 0.0 |
| 5 | 0 | 47.0 | 0 | 0 | 1 | 1 | 210.95 | 50.1 | 0.0 |
| 6 | 0 | 52.0 | 0 | 0 | 1 | 1 | 77.59 | 17.7 | 1.0 |
| 7 | 0 | 75.0 | 0 | 1 | 1 | 0 | 243.53 | 27.0 | 0.0 |
| 8 | 0 | 32.0 | 0 | 0 | 1 | 0 | 77.67 | 32.3 | 1.0 |
| 9 | 0 | 74.0 | 1 | 0 | 1 | 1 | 205.84 | 54.6 | 0.0 |
| 10 | 0 | 79.0 | 0 | 0 | 1 | 1 | 77.08 | 35.0 | 0.0 |
| 11 | 1 | 79.0 | 0 | 1 | 1 | 1 | 57.08 | 22.0 | 1.0 |
| 12 | 0 | 37.0 | 0 | 0 | 1 | 0 | 162.96 | 39.4 | 0.0 |
| 13 | 0 | 37.0 | 0 | 0 | 1 | 0 | 73.50 | 26.1 | 1.0 |
| 14 | 0 | 40.0 | 0 | 0 | 1 | 0 | 95.04 | 42.4 | 0.0 |
| 15 | 1 | 35.0 | 0 | 0 | 0 | 0 | 85.37 | 33.0 | 0.0 |
| 16 | 0 | 20.0 | 0 | 0 | 0 | 1 | 84.62 | 19.7 | 1.0 |
| 17 | 0 | 42.0 | 0 | 0 | 1 | 0 | 82.67 | 22.5 | 0.0 |
| 18 | 0 | 44.0 | 0 | 0 | 1 | 1 | 57.33 | 24.6 | 1.0 |
| 19 | 0 | 79.0 | 0 | 1 | 1 | 1 | 67.84 | 25.2 | 1.0 |
| 20 | 0 | 65.0 | 1 | 0 | 1 | 0 | 75.70 | 41.8 | 0.0 |
| 21 | 0 | 57.0 | 1 | 0 | 1 | 0 | 129.54 | 60.9 | 1.0 |
| 22 | 0 | 49.0 | 0 | 0 | 1 | 0 | 60.22 | 31.5 | 1.0 |
| 23 | 1 | 71.0 | 0 | 0 | 1 | 1 | 198.21 | 27.3 | 1.0 |
| 24 | 0 | 59.0 | 0 | 0 | 1 | 1 | 109.82 | 23.7 | 0.0 |
| 25 | 0 | 25.0 | 0 | 0 | 1 | 1 | 60.84 | 24.5 | 0.0 |
| 26 | 0 | 67.0 | 0 | 0 | 1 | 0 | 94.61 | 28.4 | 1.0 |
| 27 | 0 | 38.0 | 0 | 0 | 0 | 0 | 97.49 | 26.9 | 0.0 |
| 28 | 0 | 54.0 | 0 | 0 | 1 | 0 | 206.72 | 26.7 | 0.0 |
| 29 | 0 | 70.0 | 0 | 0 | 1 | 0 | 214.45 | 31.2 | 0.0 |
| ... | ... | ... | ... | ... | ... | ... | ... | ... | ... |
| 43370 | 1 | 79.0 | 0 | 0 | 1 | 0 | 100.71 | 25.2 | 0.0 |
| 43371 | 0 | 78.0 | 0 | 0 | 1 | 1 | 182.45 | 0.0 | 1.0 |
| 43372 | 1 | 36.0 | 1 | 0 | 1 | 1 | 149.22 | 33.4 | 0.0 |
| 43373 | 0 | 46.0 | 0 | 0 | 0 | 0 | 88.66 | 34.5 | 1.0 |
| 43374 | 1 | 49.0 | 0 | 0 | 1 | 0 | 98.91 | 35.3 | 1.0 |
| 43375 | 0 | 70.0 | 0 | 0 | 1 | 0 | 121.45 | 27.3 | 0.0 |
| 43376 | 0 | 3.0 | 0 | 0 | 0 | 0 | 83.32 | 21.3 | 0.0 |
| 43377 | 0 | 63.0 | 0 | 0 | 1 | 1 | 88.08 | 33.6 | 0.0 |
| 43378 | 1 | 22.0 | 0 | 0 | 0 | 1 | 86.82 | 30.2 | 0.0 |
| 43379 | 1 | 47.0 | 0 | 1 | 1 | 1 | 89.25 | 29.0 | 0.0 |
| 43380 | 1 | 26.0 | 0 | 0 | 1 | 0 | 71.31 | 25.3 | 1.0 |
| 43381 | 1 | 45.0 | 0 | 0 | 1 | 1 | 214.05 | 40.5 | 1.0 |
| 43382 | 0 | 9.0 | 0 | 0 | 0 | 1 | 68.49 | 16.8 | 0.0 |
| 43383 | 1 | 18.0 | 0 | 0 | 0 | 1 | 131.73 | 24.9 | 0.0 |
| 43384 | 0 | 65.0 | 0 | 0 | 1 | 0 | 200.92 | 30.7 | 1.0 |
| 43385 | 0 | 66.0 | 0 | 0 | 1 | 1 | 92.10 | 24.8 | 0.0 |
| 43386 | 1 | 68.0 | 0 | 1 | 1 | 1 | 113.60 | 25.5 | 0.0 |
| 43387 | 1 | 20.0 | 0 | 0 | 0 | 0 | 83.37 | 26.5 | 0.0 |
| 43388 | 0 | 64.0 | 1 | 0 | 1 | 0 | 228.43 | 0.0 | 1.0 |
| 43389 | 1 | 14.0 | 0 | 0 | 0 | 1 | 82.48 | 24.8 | 0.0 |
| 43390 | 0 | 69.0 | 0 | 0 | 1 | 1 | 229.85 | 31.2 | 0.0 |
| 43391 | 1 | 6.0 | 0 | 0 | 0 | 1 | 77.48 | 19.1 | 0.0 |
| 43392 | 0 | 18.0 | 0 | 0 | 0 | 1 | 131.96 | 22.8 | 0.0 |
| 43393 | 1 | 39.0 | 0 | 0 | 1 | 0 | 132.22 | 31.6 | 0.0 |
| 43394 | 1 | 47.0 | 0 | 0 | 0 | 1 | 68.52 | 25.2 | 1.0 |
| 43395 | 0 | 10.0 | 0 | 0 | 0 | 1 | 58.64 | 20.4 | 0.0 |
| 43396 | 0 | 56.0 | 0 | 0 | 1 | 1 | 213.61 | 55.4 | 1.0 |
| 43397 | 0 | 82.0 | 1 | 0 | 1 | 1 | 91.94 | 28.9 | 1.0 |
| 43398 | 1 | 40.0 | 0 | 0 | 1 | 1 | 99.16 | 33.2 | 0.0 |
| 43399 | 0 | 82.0 | 0 | 0 | 1 | 1 | 79.48 | 20.6 | 0.0 |
43400 rows × 9 columns
BPdata.astype('float')
| gender | age | hypertension | heart_disease | ever_married | Residence_type | avg_glucose_level | bmi | smoking_status | |
|---|---|---|---|---|---|---|---|---|---|
| 0 | 1.0 | 3.0 | 0.0 | 0.0 | 0.0 | 0.0 | 95.12 | 18.0 | NaN |
| 1 | 1.0 | 58.0 | 1.0 | 0.0 | 1.0 | 1.0 | 87.96 | 39.2 | 0.0 |
| 2 | 0.0 | 8.0 | 0.0 | 0.0 | 0.0 | 1.0 | 110.89 | 17.6 | NaN |
| 3 | 0.0 | 70.0 | 0.0 | 0.0 | 1.0 | 0.0 | 69.04 | 35.9 | 1.0 |
| 4 | 1.0 | 14.0 | 0.0 | 0.0 | 0.0 | 0.0 | 161.28 | 19.1 | NaN |
| 5 | 0.0 | 47.0 | 0.0 | 0.0 | 1.0 | 1.0 | 210.95 | 50.1 | NaN |
| 6 | 0.0 | 52.0 | 0.0 | 0.0 | 1.0 | 1.0 | 77.59 | 17.7 | 1.0 |
| 7 | 0.0 | 75.0 | 0.0 | 1.0 | 1.0 | 0.0 | 243.53 | 27.0 | 0.0 |
| 8 | 0.0 | 32.0 | 0.0 | 0.0 | 1.0 | 0.0 | 77.67 | 32.3 | 1.0 |
| 9 | 0.0 | 74.0 | 1.0 | 0.0 | 1.0 | 1.0 | 205.84 | 54.6 | 0.0 |
| 10 | 0.0 | 79.0 | 0.0 | 0.0 | 1.0 | 1.0 | 77.08 | 35.0 | NaN |
| 11 | 1.0 | 79.0 | 0.0 | 1.0 | 1.0 | 1.0 | 57.08 | 22.0 | 1.0 |
| 12 | 0.0 | 37.0 | 0.0 | 0.0 | 1.0 | 0.0 | 162.96 | 39.4 | 0.0 |
| 13 | 0.0 | 37.0 | 0.0 | 0.0 | 1.0 | 0.0 | 73.50 | 26.1 | 1.0 |
| 14 | 0.0 | 40.0 | 0.0 | 0.0 | 1.0 | 0.0 | 95.04 | 42.4 | 0.0 |
| 15 | 1.0 | 35.0 | 0.0 | 0.0 | 0.0 | 0.0 | 85.37 | 33.0 | 0.0 |
| 16 | 0.0 | 20.0 | 0.0 | 0.0 | 0.0 | 1.0 | 84.62 | 19.7 | 1.0 |
| 17 | 0.0 | 42.0 | 0.0 | 0.0 | 1.0 | 0.0 | 82.67 | 22.5 | 0.0 |
| 18 | 0.0 | 44.0 | 0.0 | 0.0 | 1.0 | 1.0 | 57.33 | 24.6 | 1.0 |
| 19 | 0.0 | 79.0 | 0.0 | 1.0 | 1.0 | 1.0 | 67.84 | 25.2 | 1.0 |
| 20 | 0.0 | 65.0 | 1.0 | 0.0 | 1.0 | 0.0 | 75.70 | 41.8 | NaN |
| 21 | 0.0 | 57.0 | 1.0 | 0.0 | 1.0 | 0.0 | 129.54 | 60.9 | 1.0 |
| 22 | 0.0 | 49.0 | 0.0 | 0.0 | 1.0 | 0.0 | 60.22 | 31.5 | 1.0 |
| 23 | 1.0 | 71.0 | 0.0 | 0.0 | 1.0 | 1.0 | 198.21 | 27.3 | 1.0 |
| 24 | 0.0 | 59.0 | 0.0 | 0.0 | 1.0 | 1.0 | 109.82 | 23.7 | 0.0 |
| 25 | 0.0 | 25.0 | 0.0 | 0.0 | 1.0 | 1.0 | 60.84 | 24.5 | 0.0 |
| 26 | 0.0 | 67.0 | 0.0 | 0.0 | 1.0 | 0.0 | 94.61 | 28.4 | 1.0 |
| 27 | 0.0 | 38.0 | 0.0 | 0.0 | 0.0 | 0.0 | 97.49 | 26.9 | 0.0 |
| 28 | 0.0 | 54.0 | 0.0 | 0.0 | 1.0 | 0.0 | 206.72 | 26.7 | 0.0 |
| 29 | 0.0 | 70.0 | 0.0 | 0.0 | 1.0 | 0.0 | 214.45 | 31.2 | 0.0 |
| ... | ... | ... | ... | ... | ... | ... | ... | ... | ... |
| 43370 | 1.0 | 79.0 | 0.0 | 0.0 | 1.0 | 0.0 | 100.71 | 25.2 | 0.0 |
| 43371 | 0.0 | 78.0 | 0.0 | 0.0 | 1.0 | 1.0 | 182.45 | NaN | 1.0 |
| 43372 | 1.0 | 36.0 | 1.0 | 0.0 | 1.0 | 1.0 | 149.22 | 33.4 | NaN |
| 43373 | 0.0 | 46.0 | 0.0 | 0.0 | 0.0 | 0.0 | 88.66 | 34.5 | 1.0 |
| 43374 | 1.0 | 49.0 | 0.0 | 0.0 | 1.0 | 0.0 | 98.91 | 35.3 | 1.0 |
| 43375 | 0.0 | 70.0 | 0.0 | 0.0 | 1.0 | 0.0 | 121.45 | 27.3 | 0.0 |
| 43376 | 0.0 | 3.0 | 0.0 | 0.0 | 0.0 | 0.0 | 83.32 | 21.3 | NaN |
| 43377 | 0.0 | 63.0 | 0.0 | 0.0 | 1.0 | 1.0 | 88.08 | 33.6 | 0.0 |
| 43378 | 1.0 | 22.0 | 0.0 | 0.0 | 0.0 | 1.0 | 86.82 | 30.2 | 0.0 |
| 43379 | 1.0 | 47.0 | 0.0 | 1.0 | 1.0 | 1.0 | 89.25 | 29.0 | NaN |
| 43380 | 1.0 | 26.0 | 0.0 | 0.0 | 1.0 | 0.0 | 71.31 | 25.3 | 1.0 |
| 43381 | 1.0 | 45.0 | 0.0 | 0.0 | 1.0 | 1.0 | 214.05 | 40.5 | 1.0 |
| 43382 | 0.0 | 9.0 | 0.0 | 0.0 | 0.0 | 1.0 | 68.49 | 16.8 | NaN |
| 43383 | 1.0 | 18.0 | 0.0 | 0.0 | 0.0 | 1.0 | 131.73 | 24.9 | 0.0 |
| 43384 | 0.0 | 65.0 | 0.0 | 0.0 | 1.0 | 0.0 | 200.92 | 30.7 | 1.0 |
| 43385 | 0.0 | 66.0 | 0.0 | 0.0 | 1.0 | 1.0 | 92.10 | 24.8 | NaN |
| 43386 | 1.0 | 68.0 | 0.0 | 1.0 | 1.0 | 1.0 | 113.60 | 25.5 | 0.0 |
| 43387 | 1.0 | 20.0 | 0.0 | 0.0 | 0.0 | 0.0 | 83.37 | 26.5 | 0.0 |
| 43388 | 0.0 | 64.0 | 1.0 | 0.0 | 1.0 | 0.0 | 228.43 | NaN | 1.0 |
| 43389 | 1.0 | 14.0 | 0.0 | 0.0 | 0.0 | 1.0 | 82.48 | 24.8 | NaN |
| 43390 | 0.0 | 69.0 | 0.0 | 0.0 | 1.0 | 1.0 | 229.85 | 31.2 | 0.0 |
| 43391 | 1.0 | 6.0 | 0.0 | 0.0 | 0.0 | 1.0 | 77.48 | 19.1 | NaN |
| 43392 | 0.0 | 18.0 | 0.0 | 0.0 | 0.0 | 1.0 | 131.96 | 22.8 | NaN |
| 43393 | 1.0 | 39.0 | 0.0 | 0.0 | 1.0 | 0.0 | 132.22 | 31.6 | 0.0 |
| 43394 | 1.0 | 47.0 | 0.0 | 0.0 | 0.0 | 1.0 | 68.52 | 25.2 | 1.0 |
| 43395 | 0.0 | 10.0 | 0.0 | 0.0 | 0.0 | 1.0 | 58.64 | 20.4 | 0.0 |
| 43396 | 0.0 | 56.0 | 0.0 | 0.0 | 1.0 | 1.0 | 213.61 | 55.4 | 1.0 |
| 43397 | 0.0 | 82.0 | 1.0 | 0.0 | 1.0 | 1.0 | 91.94 | 28.9 | 1.0 |
| 43398 | 1.0 | 40.0 | 0.0 | 0.0 | 1.0 | 1.0 | 99.16 | 33.2 | 0.0 |
| 43399 | 0.0 | 82.0 | 0.0 | 0.0 | 1.0 | 1.0 | 79.48 | 20.6 | 0.0 |
43400 rows × 9 columns
print(BPdata.dtypes)
gender int64 age float64 hypertension int64 heart_disease int64 ever_married int64 Residence_type int64 avg_glucose_level float64 bmi float64 smoking_status float64 dtype: object
BPdata.fillna(0,inplace=True)
from sklearn.model_selection import train_test_split
X_train, X_test, y_train, y_test = train_test_split(BPdata, target, random_state=42)
clf = tree.DecisionTreeClassifier()
clf = clf.fit(X_train, y_train)
clf.score(X_test, y_test)
0.9626728110599079
# WARNING! BOILERPLATE CODE HERE!
# Use this to visualize the tree
import graphviz
dot_data = tree.export_graphviz(
clf, out_file=None,
feature_names=feature_names,
class_names=target_names,
filled=True, rounded=True,
special_characters=True)
graph = graphviz.Source(dot_data)
graph
import pydotplus
graph = pydotplus.graph_from_dot_data(dot_data)
graph.write_png('randomforest.png')
graph = graphviz.Source(dot_data)
graph
from sklearn.ensemble import RandomForestClassifier
rf = RandomForestClassifier(n_estimators=200)
rf = rf.fit(X_train, y_train)
rf.score(X_test, y_test)
0.9802764976958526
sorted(zip(rf.feature_importances_, feature_names), reverse=True)
[(0.39346978499713897, 'avg_glucose_level'), (0.27548411711469956, 'bmi'), (0.19299537341010364, 'age'), (0.030943686214760686, 'Residence_type'), (0.029839428898182706, 'gender'), (0.026374686960822822, 'smoking_status'), (0.020392386361415906, 'hypertension'), (0.016433023961452426, 'heart_disease'), (0.014067512081423544, 'ever_married')]